from ActualCausal.Train.Passive.train_passive import train_passive
from ActualCausal.Train.Active.train_full_active import train_full_active
from ActualCausal.Train.Active.train_mask_active import train_masked_active
from ActualCausal.Train.Active.train_given_active import train_given_active
from ActualCausal.Train.Active.train_random_mask_active import train_random_active
from ActualCausal.Train.Active.train_mask_inter import train_mask_inter
from ActualCausal.Train.Inter.train_inter import train_inter
from ActualCausal.Train.Inter.train_trace import train_trace
from ActualCausal.Train.EM.train_em import train_EM
from ActualCausal.Train.Inter.train_null_bin import train_null_bin
from ActualCausal.Train.Inter.train_null_assign import train_null_assign
from ActualCausal.Train.Cluster.train_cluster_active import train_cluster_active
from ActualCausal.Train.Cluster.train_cluster_inter import train_cluster_inter
from tianshou.data import Batch
import time

def needed_additional(args, params):
    '''
    Decide which additional model outputs the training step must request.

    Each regularizer with a positive weight in ``args.inter.regularization``
    requires a particular intermediate value:
        embedding.embed_reg_lambda > 0          -> "pre_embeddings_query"
        null_consistency.null_reg_lambda > 0    -> "pre_embeddings_query"
        embedding.mask_embed_reg_lambda > 0     -> "masked_pre_embeddings"
        attention.attn_reg_lambda > 0           -> "attn"

    :param args: config namespace providing the ``inter.regularization`` fields above
    :param params: unused; kept for interface compatibility
    :return: list of required value names, in the order above, each listed at most once
    '''
    # TODO: add more components as needed
    reg = args.inter.regularization
    # Read every weight up front (same order as before) so a missing config
    # field still raises instead of being skipped by short-circuiting.
    use_embed = reg.embedding.embed_reg_lambda > 0
    use_null = reg.null_consistency.null_reg_lambda > 0
    use_mask_embed = reg.embedding.mask_embed_reg_lambda > 0
    use_attn = reg.attention.attn_reg_lambda > 0

    needed_values = list()
    # The embedding and null-consistency regularizers consume the same value;
    # request it once rather than listing it twice.
    if use_embed or use_null:
        needed_values.append("pre_embeddings_query")
    if use_mask_embed:
        needed_values.append("masked_pre_embeddings")
    if use_attn:
        needed_values.append("attn")
    return needed_values

def apply_train(train_form, result, params, model, args, train_buffer, additional, log_batch, wrap_function, i, intermediate_logger):
    '''
    Run a single train function selected by ``train_form`` and store its output in ``result``.

    Per-target forms (e.g. "single_passive", "full", "mask", "inter", "em") set the
    model's target name to each entry of ``args.inter.train_names`` in turn
    ("pair" uses ``args.inter.pair_names``) and collect the per-name outputs in a
    Batch keyed by name. They do nothing if the name list is empty. Forms such as
    "passive" or the "all_*" variants are run with a single call.

    Forms starting with "given" (e.g. "given_<suffix>") call train_given_active
    with ``<suffix>`` as the given string. Their output is stored under
    ``result[train_form]``.

    Available train forms are described in train_model.

    :param train_form: name of the training procedure to run
    :param result: Batch accumulating outputs; mutated in place and returned
    :param params: forwarded unchanged to the train functions
    :param model: model being trained; its target name is set per name for per-target forms
    :param args: config namespace (reads ``args.inter.train_names`` / ``args.inter.pair_names``)
    :param train_buffer: buffer forwarded to the train functions
    :param additional: extra model outputs to request (see needed_additional)
    :param log_batch: names of batch values to keep in the output
    :param wrap_function: batch-conversion function, forwarded only to some train functions
    :param i: current iteration number, forwarded as ``itr_num``
    :param intermediate_logger: logger forwarded to the train functions
    :return: ``result``
    '''
    # Only a leading "given" marks a given-form: the slice below assumes the
    # "given_" prefix, so matching "given" elsewhere in the name would produce a
    # meaningless given_str.
    if train_form.startswith("given"):
        given_str = train_form[len("given_"):]
        train_form_name = train_form
        train_form = "given"
    if "passive" == train_form:
        result.passive = train_passive(args, params, model, train_buffer, single=False, log_batch=log_batch, additional=additional, wrap_function=wrap_function, itr_num=i, intermediate_logger=intermediate_logger)
    if "single_passive" == train_form:
        if len(args.inter.train_names) > 0:
            result.passive = Batch()
            for name in args.inter.train_names:
                model.set_target_name(name)
                result.passive[name] = train_passive(args, params, model, train_buffer, single=True, name=name, log_batch=log_batch, additional=additional, itr_num=i, intermediate_logger=intermediate_logger)
    if "full" == train_form:
        if len(args.inter.train_names) > 0:
            result.full = Batch()
            for name in args.inter.train_names:
                model.set_target_name(name)
                result.full[name] = train_full_active(args, params, model, train_buffer, form="full", name=name, log_batch=log_batch, additional=additional, itr_num=i, intermediate_logger=intermediate_logger)
    if "pair" == train_form:
        if len(args.inter.pair_names) > 0:
            result.pair = Batch()
            for name in args.inter.pair_names:
                model.set_target_name(name)
                # pair names look like "parent->target"; the target is the last segment
                names = name.split('->')
                result.pair[name] = train_full_active(args, params, model, train_buffer, form="pair", name=names[-1], log_batch=log_batch, additional=additional, itr_num=i, intermediate_logger=intermediate_logger)
    if "all_full" == train_form:
        result.all_full = train_full_active(args, params, model, train_buffer, form="all", log_batch=log_batch, additional=additional, wrap_function=wrap_function, itr_num=i, intermediate_logger=intermediate_logger)
    if "mask" == train_form or "mask_both" == train_form:  # TODO can't have mask_both and mask simultaneously
        if len(args.inter.train_names) > 0:
            result.mask = Batch()
            for name in args.inter.train_names:
                model.set_target_name(name)
                result.mask[name] = train_masked_active(args, params, model, train_buffer, form="full", name=name, log_batch=log_batch, additional=additional, both="mask_both" == train_form, itr_num=i, intermediate_logger=intermediate_logger)
    if "all_mask" == train_form or "all_mask_both" == train_form:
        result.all_mask = train_masked_active(args, params, model, train_buffer, log_batch=log_batch, additional=additional, wrap_function=wrap_function, both = "all_mask_both" == train_form, itr_num=i, intermediate_logger=intermediate_logger)
    if "mask_inter" == train_form:  # TODO mask_inter_both is probably really easy to implement
        if len(args.inter.train_names) > 0:
            result.mask, result.inter = Batch(), Batch()
            for name in args.inter.train_names:
                model.set_target_name(name)
                # NOTE(review): this branch only runs for "mask_inter", so `both` is always False here
                result.mask[name], result.inter[name] = train_mask_inter(args, params, model, train_buffer, form="full", name=name, log_batch=log_batch, additional=additional, both="mask_inter_both" == train_form, itr_num=i, intermediate_logger=intermediate_logger)
    if "all_mask_inter" == train_form:
        # NOTE(review): this branch only runs for "all_mask_inter", so `both` is always False here
        result.all_mask_inter = train_mask_inter(args, params, model, train_buffer, log_batch=log_batch, additional=additional, wrap_function=wrap_function, both = "all_mask_inter_both" == train_form, itr_num=i, intermediate_logger=intermediate_logger)
    if "rand_mask" == train_form:
        if len(args.inter.train_names) > 0:
            result.rand_mask = Batch()
            for name in args.inter.train_names:
                model.set_target_name(name)
                # NOTE(review): unlike the other per-name forms, `name` is not forwarded here;
                # confirm train_random_active relies only on model.set_target_name
                result.rand_mask[name] = train_random_active(args, params, model, train_buffer, form="full", log_batch=log_batch, additional=additional, itr_num=i, intermediate_logger=intermediate_logger)
    if "all_rand_mask" == train_form:
        result.all_rand_mask = train_random_active(args, params, model, train_buffer, form="all", log_batch=log_batch, additional=additional, itr_num=i, intermediate_logger=intermediate_logger)
    if "given" == train_form:
        if len(args.inter.train_names) > 0:
            result[train_form_name] = Batch()
            for name in args.inter.train_names:
                model.set_target_name(name)
                result[train_form_name][name] = train_given_active(args, params, model, train_buffer, given_str, form="full", log_batch=log_batch, wrap_function=wrap_function, additional=additional, itr_num=i, intermediate_logger = intermediate_logger)
    if "binaries" == train_form:
        result.binaries = train_trace(args, params, model, train_buffer, log_batch=log_batch, additional=additional, itr_num=i, intermediate_logger=intermediate_logger)
    if "null_bin" == train_form:
        if len(args.inter.train_names) > 0:
            result.null_bin = Batch()
            for name in args.inter.train_names:
                model.set_target_name(name)
                result.null_bin[name] = train_null_bin(args, params, model, train_buffer, log_batch=log_batch, additional=additional, name=name, itr_num=i, intermediate_logger=intermediate_logger)
    if "null_assign" == train_form:
        if len(args.inter.train_names) > 0:
            result.null_assign = Batch()
            for name in args.inter.train_names:
                model.set_target_name(name)
                result.null_assign[name] = train_null_assign(args, params, model, train_buffer, log_batch=log_batch, additional=additional, name=name, itr_num=i, intermediate_logger=intermediate_logger)
    if "inter" == train_form or "inter_both" == train_form:
        if len(args.inter.train_names) > 0:
            result.inter = Batch()
            for name in args.inter.train_names:
                model.set_target_name(name)
                result.inter[name] = train_inter(args, params, model, train_buffer, form="full", name=name, log_batch=log_batch, additional=additional, both="inter_both" == train_form, itr_num=i, intermediate_logger=intermediate_logger)
    if "all_inter" == train_form or "all_inter_both" == train_form:
        result.all_inter = train_inter(args, params, model, train_buffer, form="all", log_batch=log_batch, additional=additional, wrap_function=wrap_function, both="all_inter_both" == train_form, itr_num=i, intermediate_logger=intermediate_logger)
    if "cluster_active" == train_form:  # should NOT do both cluster and masked active or train_inter, since these share the same networks
        result.cluster_active = train_cluster_active(args, params, model, train_buffer, log_batch=log_batch, additional=additional, itr_num=i, intermediate_logger=intermediate_logger)
    if "cluster_inter" == train_form:  # should NOT do both cluster and masked active or train_inter, since these share the same networks
        # NOTE(review): stored under `cluster_active`, so it overwrites the cluster_active
        # output if both forms run; kept as-is because callers may read this key
        result.cluster_active = train_cluster_inter(args, params, model, train_buffer, log_batch=log_batch, additional=additional, itr_num=i, intermediate_logger=intermediate_logger)
    # if performing all training, the logic is currently exactly the same as full training
    # TODO: in the future this may differ, depending on if there is factor-based weighting
    if "em" == train_form:
        if len(args.inter.train_names) > 0:
            result.em = Batch()
            for name in args.inter.train_names:
                model.set_target_name(name)
                result.em[name] = train_EM(args, params, model, train_buffer, form="full", name=name, log_batch=log_batch, additional=additional, itr_num=i, intermediate_logger=intermediate_logger)
    if "all_em" == train_form:
        # NOTE(review): shares the `em` key with the per-name "em" form
        result.em = train_EM(args, params, model, train_buffer, form="all", log_batch=log_batch, additional=additional, itr_num=i, intermediate_logger=intermediate_logger)
    return result


def train_model(i, args, params, model, train_buffer, log_batch=None, wrap_function=None, intermediate_logger=None):
    '''
    Run every train form listed in ``args.inter.train_forms``, in order.

    The heavy lifting is done in the called train functions; any adaptive
    parameters must be computed inside them. Each of them has the general structure:
        sample batch
        call model.infer with the desired operation
        call compute_likelihood with the appropriate values
        generate the loss
        get the optimizer and run it, assigning the gradients

    :param i: the current iteration number, used for logging
    :param args: config namespace; reads ``args.inter.train_forms`` and the fields
        used by needed_additional / apply_train
    :param params: forwarded to needed_additional and the train functions
    :param model: the model being trained
    :param train_buffer: buffer the train functions sample from
    :param log_batch: names of values from the batch to store in the output.
        Defaults to a fresh empty list on every call.
    :param wrap_function: takes a batch sampled from a buffer that uses different
        key naming conventions (or is missing certain keys) and converts it into
        one usable for training. The logic occurs in the train functions.
    :param intermediate_logger: logger for intermediate values
    :return: Batch with one entry per train form that produced output

    Available train forms:
    passive: trains a unified passive model for all factors
    single_passive: trains a passive model for a particular name
    full: trains a model mapping all parents to a target without masking
    pair: trains a model mapping some number of parents to a single target
    all_full: trains a model mapping all parents to all targets without masking
    mask: all parents to a target with masking
    mask_both: all parents to a target with masking, optimizing the masking model
    all_mask: all parents to all targets with per-parent masking
    all_mask_both: all parents to all targets with per-parent masking, optimizing the masking model
    mask_inter: all-1 forward model with regularized masking model
    all_mask_inter: all-all forward model with regularized masking model
    rand_mask: trains an all-1 forward model with random masks
    all_rand_mask: trains an all-all forward model with random masks
    given_<suffix>: all-1 training via train_given_active with <suffix> as the given string
    binaries: trains an inference model to match per-factor binaries
    null_bin: evaluates null inference values, then trains a model to match them
    null_assign: assigns keys null_traces, null_weights to the buffer using null inference
    inter: all-1 regularized masking
    inter_both: all-1 regularized masking with both=True passed to train_inter
    all_inter: all-all regularized masking
    all_inter_both: all-all regularized masking and forward model (with the regularized objective)
    cluster_active: assignment of interaction clusters, forward model training
    cluster_inter: assignment of interaction clusters, regularized mask model training
    em: all-1 alternate between forward and mask assignment
    all_em: all-all alternate between forward and mask assignment
    '''
    # A shared mutable default list would carry any in-place changes made by a
    # train function over into later calls, so build a fresh one each time.
    if log_batch is None:
        log_batch = []
    result = Batch()
    additional = needed_additional(args, params)
    for train_form in args.inter.train_forms:
        result = apply_train(train_form, result, params, model, args, train_buffer, additional, log_batch, wrap_function, i, intermediate_logger)
    return result